import * as ort from 'onnxruntime-web';
import * as math from 'mathjs';

// Builds an empty board image: a boardSize x boardSize grid filled with zeros.
// Every row is a separate array, so writing into one row never touches another.
const getBlankImage = (boardSize) => {
    const rows = [];
    for (let y = 0; y < boardSize; y++) {
        rows.push(new Array(boardSize).fill(0));
    }
    return rows;
};

// Marks the apple on the board image by writing 10 into its cell.
// applePosition is [x, y]; the image is indexed as image[y][x].
// The function is async, so it returns a promise, but the write happens synchronously during the call.
const drawApple = async (image, applePosition) => {
    const row = applePosition[1];
    const column = applePosition[0];
    image[row][column] = 10;
};

// Paints the snake onto the board image (indexed as image[y][x]).
// encodedMoves[i] is written at snakePositions[i + 1], i.e. one code per segment behind the head.
// The head cell (snakePositions[0]) is written last as 1, and only while the game is not done.
// Codes:
// 1 - head, 2 - body Left, 3 - body Up, 4 - body Down, 5 - body Right, 6 - tail Left, 7 - tail Up, 8 - tail Down, 9 - tail Right
const drawSnake = (image, snakePositions, encodedMoves, done=false) => {
    for (const [index, code] of encodedMoves.entries()) {
        const segment = snakePositions[index + 1];
        image[segment[1]][segment[0]] = code;
    }
    if (done) {
        return;
    }
    const head = snakePositions[0];
    image[head[1]][head[0]] = 1;
};

// Derives the move code of every segment behind the head from the snake's positions.
// snakePositions is ordered head first; each entry is [x, y].
// For each pair of neighbours the step (later segment minus earlier segment) is weighted as 2*dx + 3*dy,
// and a negative result is wrapped by adding 7. For unit steps this gives:
// 2 - Left, 3 - Up, 4 - Down, 5 - Right (the same codes stepSnake pushes for a move in that direction).
// The last code is raised by 4 to mark the tail: 6 - tail Left, 7 - tail Up, 8 - tail Down, 9 - tail Right.
// Returns one code per segment behind the head; a snake with fewer than two segments yields [].
const getEncodedMoves = (snakePositions) => {
    const codes = [];
    for (let i = 1; i < snakePositions.length; i++) {
        const dx = snakePositions[i][0] - snakePositions[i - 1][0];
        const dy = snakePositions[i][1] - snakePositions[i - 1][1];
        const weighted = 2 * dx + 3 * dy;
        codes.push(weighted < 0 ? weighted + 7 : weighted);
    }
    // Only flag a tail when there is at least one segment behind the head.
    if (codes.length > 0) {
        codes[codes.length - 1] += 4;
    }
    return codes;
};

// Returns the default three-segment snake for a board of the given size.
// The head sits at the centre cell (floor(boardSize / 2) on both axes) and the two
// following segments lie directly to its left on the same row.
const getStartSnake = (boardSize) => {
    const center = Math.floor(boardSize / 2);
    return [0, 1, 2].map((offset) => [center - offset, center]);
};

// Chooses the position of the next apple as [x, y].
// While the module-level flag userSelectsApple is set, it clears userSelectedApplePosition and polls it
// every 10 ms until it holds a cell that is on the board and not covered by the snake.
// If the flag is (or becomes) false, a free cell is picked uniformly at random instead.
// The free cells are enumerated up front rather than sampled by retrying, so the function cannot spin
// forever on a full board; in that case the returned promise rejects with an Error.
const getNewApple = async (boardSize, snakePositions) => {
    const occupied = new Set(snakePositions.map((segment) => `${segment[0]},${segment[1]}`));
    const freeCells = [];
    for (let y = 0; y < boardSize; y++) {
        for (let x = 0; x < boardSize; x++) {
            if (!occupied.has(`${x},${y}`)) {
                freeCells.push([x, y]);
            }
        }
    }
    if (freeCells.length === 0) {
        throw new Error('getNewApple: no free cell left on the board');
    }

    const isUsable = (position) =>
        position !== null
        && position[0] >= 0 && position[0] < boardSize
        && position[1] >= 0 && position[1] < boardSize
        && !occupied.has(`${position[0]},${position[1]}`);

    if (userSelectsApple) {
        // Forget any click made before this apple was requested.
        userSelectedApplePosition = null;
        while (userSelectsApple) {
            const candidate = userSelectedApplePosition;
            if (isUsable(candidate)) {
                return candidate;
            }
            await sleep(10);
        }
    }
    return freeCells[Math.floor(Math.random() * freeCells.length)];
};

// Advances the snake by one move and returns [snakePositions, encodedMoves, applePosition, done, ateApple].
// snakePositions and encodedMoves are modified in place and also returned.
// Moves: 0 - left (x - 1), 1 - right (x + 1), 2 - down (y + 1), 3 - up (y - 1).
// Any other value leaves the head where it is and pushes no code.
// done becomes true when the new head is off the board, when it lands on a body segment other than the tail,
// or when the snake fills the whole board after eating.
// When the apple is eaten the tail is kept and a new apple is requested (unless the board is full);
// otherwise the last segment is dropped and the new last code is raised by 4 to mark it as the tail.
const stepSnake = async (snakePositions, encodedMoves, applePosition, boardSize, move) => {
    // move -> how the head advances and which body code the old head cell receives
    const moveTable = new Map([
        [0, { code: 2, advance: (head) => [head[0] - 1, head[1]] }], // left
        [1, { code: 5, advance: (head) => [head[0] + 1, head[1]] }], // right
        [2, { code: 4, advance: (head) => [head[0], head[1] + 1] }], // down
        [3, { code: 3, advance: (head) => [head[0], head[1] - 1] }], // up
    ]);

    let snakeHead = snakePositions[0].slice();
    const chosen = moveTable.get(move);
    if (chosen !== undefined) {
        snakeHead = chosen.advance(snakeHead);
        encodedMoves.unshift(chosen.code);
    }

    const hitWall = snakeHead[0] < 0 || snakeHead[0] >= boardSize || snakeHead[1] < 0 || snakeHead[1] >= boardSize;
    // The tail is excluded from the collision check, as is the old head.
    const hitBody = snakePositions
        .slice(1, snakePositions.length - 1)
        .some((segment) => segment[0] === snakeHead[0] && segment[1] === snakeHead[1]);
    let done = hitWall || hitBody;

    snakePositions.unshift(snakeHead);
    const ateApple = snakeHead[0] === applePosition[0] && snakeHead[1] === applePosition[1];
    if (!ateApple) {
        snakePositions.pop();
        encodedMoves.pop();
        encodedMoves[encodedMoves.length - 1] += 4;
        return [snakePositions, encodedMoves, applePosition, done, ateApple];
    }

    if (userSelectsApple) {
        // Show the board without an apple while the user picks the next one.
        drawImage(snakePositions, applePosition, boardSize, ctx, false);
    }
    if (snakePositions.length === boardSize * boardSize) {
        done = true;
    } else {
        applePosition = await getNewApple(boardSize, snakePositions);
    }
    return [snakePositions, encodedMoves, applePosition, done, ateApple];
};

// Returns four flags (1 - safe, 0 - not safe) describing the cells next to the snake's head.
// A direction is unsafe when the head is already on that edge of the board, or when the neighbouring
// cell is occupied by a segment other than the head and the tail.
// Index order: 0 - left (x - 1), 1 - right (x + 1), 2 - up (y - 1), 3 - down (y + 1).
// NOTE(review): stepSnake maps move 2 to y + 1 and move 3 to y - 1, so indices 2 and 3 here do not line up
// with its move codes — confirm which order the model was trained with before changing either side.
const getSafeMoves = (snakePositions, boardSize) => {
    const [headX, headY] = snakePositions[0];
    const bodyBesidesTail = snakePositions.slice(1, -1);
    const directions = [
        { target: [headX - 1, headY], atEdge: headX === 0 },
        { target: [headX + 1, headY], atEdge: headX === boardSize - 1 },
        { target: [headX, headY - 1], atEdge: headY === 0 },
        { target: [headX, headY + 1], atEdge: headY === boardSize - 1 },
    ];
    return directions.map(({ target, atEdge }) => {
        const blocked = bodyBesidesTail.some((segment) => segment[0] === target[0] && segment[1] === target[1]);
        return atEdge || blocked ? 0 : 1;
    });
};

// Updates the move budget and returns [remainingMoves, done].
// With resetMoves set, the budget is restored to boardSize * boardSize + 5 and done is false.
// Otherwise one move is spent, and done is true once the budget has dropped to 0 or below.
const getRemainingMoves = (numMoves, boardSize, resetMoves=false) => {
    if (resetMoves) {
        return [boardSize * boardSize + 5, false];
    }
    const movesLeft = numMoves - 1;
    return [movesLeft, movesLeft <= 0];
};

// Returns the signed offset from the apple to the snake's head: head position minus apple position,
// computed element-wise with math.subtract on a copy of the head.
const getDistanceToApple = (snakePositions, applePosition) => {
    const headCopy = snakePositions[0].slice();
    return math.subtract(headCopy, applePosition);
};

// Tracks, per axis, the smallest offset to the apple seen so far.
// For each axis the current head-to-apple offset replaces the stored one when its absolute value is smaller.
// closestDistance is updated in place and returned.
const getClosestDistance = (snakePositions, applePosition, closestDistance) => {
    const currentDistance = getDistanceToApple(snakePositions, applePosition);
    for (const axis of [0, 1]) {
        const isCloser = math.abs(currentDistance[axis]) < math.abs(closestDistance[axis]);
        if (isCloser) {
            closestDistance[axis] = currentDistance[axis];
        }
    }
    return closestDistance;
};

// Builds the flat observation vector fed to the model.
// Layout: the board image flattened row by row and divided by 10, then the two closest-distance values
// shifted by boardSize and divided by 2 * boardSize, then the remaining moves divided by
// boardSize * boardSize + 5, then the four safe-move flags unchanged.
// That is boardSize * boardSize + 7 values, matching the tensor shape used in doSteps.
const normAndFlatten = (image, closestDistance, remainingMoves, safeMoves, boardSize) => {
    const pixels = math.divide(image.flat(), 10);
    const distance = math.divide(math.add(closestDistance, boardSize), boardSize * 2);
    const movesLeft = math.divide(remainingMoves, boardSize * boardSize + 5);
    return pixels.concat(distance, movesLeft, safeMoves);
};

// Draws white, 1-pixel grid lines on the canvas behind ctx.
// The canvas is split into boardSize columns and rows; one vertical and one horizontal line
// is stroked at the start of each column and row (indices 0 .. boardSize - 1).
const drawGrid = (boardSize, ctx) => {
    const imageWidth = ctx.canvas.width;
    const imageHeight = ctx.canvas.height;
    const cellWidth = imageWidth / boardSize;
    const cellHeight = imageHeight / boardSize;

    const strokeLine = (fromX, fromY, toX, toY) => {
        ctx.beginPath();
        ctx.moveTo(fromX, fromY);
        ctx.lineTo(toX, toY);
        ctx.stroke();
    };

    ctx.strokeStyle = 'white';
    ctx.lineWidth = 1;
    for (let index = 0; index < boardSize; index++) {
        strokeLine(index * cellWidth, 0, index * cellWidth, imageHeight);
        strokeLine(0, index * cellHeight, imageWidth, index * cellHeight);
    }
};

// Renders the current game onto the canvas behind ctx.
// The canvas is cleared and filled black, the apple is drawn as a red cell, the grid is drawn when the
// module-level flag shouldDrawGrid is set, and finally every snake segment is drawn as a filled cell.
// Segment colours are interpolated linearly from startColor at the head (index 0) to endColor at the
// last segment, so the tail is drawn in exactly the selected tail colour.
// drawApple=false skips the apple (used while the user is choosing its position; applePosition may then
// be undefined). Note that this parameter shadows the drawApple function inside this body.
const drawImage = (snakePositions, applePosition, boardSize, ctx, drawApple=true) => {
    const imageWidth = ctx.canvas.width;
    const imageHeight = ctx.canvas.height;
    const pixelWidth = imageWidth / boardSize;
    const pixelHeight = imageHeight / boardSize;

    ctx.clearRect(0, 0, imageWidth, imageHeight);
    ctx.fillStyle = 'black';
    ctx.fillRect(0, 0, imageWidth, imageHeight);

    if (drawApple) {
        ctx.fillStyle = 'red';
        ctx.fillRect(applePosition[0] * pixelWidth, applePosition[1] * pixelHeight, pixelWidth, pixelHeight);
    }

    if (shouldDrawGrid) {
        drawGrid(boardSize, ctx);
    }

    // Interpolate over (length - 1) steps so the last segment lands on endColor;
    // the Math.max guard keeps a single-segment snake from dividing by zero.
    const lastIndex = Math.max(snakePositions.length - 1, 1);
    snakePositions.forEach((segment, index) => {
        const channels = [0, 1, 2].map((channel) =>
            Math.round(startColor[channel] + (endColor[channel] - startColor[channel]) * index / lastIndex));
        ctx.fillStyle = `rgb(${channels[0]}, ${channels[1]}, ${channels[2]})`;
        ctx.fillRect(segment[0] * pixelWidth, segment[1] * pixelHeight, pixelWidth, pixelHeight);
    });
};

// Returns a promise that resolves (with no value) once a setTimeout of ms milliseconds fires.
const sleep = (ms) => new Promise((resolve) => {
    setTimeout(resolve, ms);
});

// Plays the game until it ends or the module-level flag paused is set.
// Each iteration applies the model's chosen move with stepSnake, rebuilds the board image, redraws the
// canvas, updates the closest distance, move budget and safe moves, asks the model for the next move,
// and then waits getDelay() milliseconds.
// numSteps is the step counter to continue from; score is the snake length minus 3.
// Returns [image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, done, score, step].
const doSteps = async (image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, boardSize, numSteps) => {
    // Runs the model on the current state and returns the first output value as a Number.
    const predictMove = async () => {
        const observation = normAndFlatten(image, closestDistance, remainingMoves, safeMoves, boardSize);
        const tensor = new ort.Tensor('float32', observation, [1, boardSize * boardSize + 7]);
        const feeds = { [model.inputNames[0]]: tensor };
        const result = await model.run(feeds);
        const actions = await result[model.outputNames[0]].data;
        return Number(actions[0]);
    };

    let score = snakePositions.length - 3;
    let done = false;
    let step = numSteps;
    let move = await predictMove();

    while (!done) {
        let ateApple;
        [snakePositions, encodedMoves, applePosition, done, ateApple] = await stepSnake(snakePositions, encodedMoves, applePosition, boardSize, move);

        image = getBlankImage(boardSize);
        drawSnake(image, snakePositions, encodedMoves, done);
        drawApple(image, applePosition);
        drawImage(snakePositions, applePosition, boardSize, ctx);

        if (!done) {
            // A fresh apple restarts the distance tracking; otherwise keep the best offset seen so far.
            closestDistance = ateApple
                ? getDistanceToApple(snakePositions, applePosition)
                : getClosestDistance(snakePositions, applePosition, closestDistance);
            [remainingMoves, done] = getRemainingMoves(remainingMoves, boardSize, ateApple);
            safeMoves = getSafeMoves(snakePositions, boardSize);
            move = await predictMove();
            score = snakePositions.length - 3;
            step += 1;
        }

        await sleep(getDelay());
        if (paused === true) {
            break;
        }
    }
    return [image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, done, score, step];
};

// Builds the initial state for a game and draws it.
// Uses snakePosArray as the snake when one is given, otherwise the default start snake.
// When userSelectsApple is set, the board is first drawn without an apple so the user can pick one.
// Returns [image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, done, score, step]
// with done = false, score = 0 and step = 0.
const prepareSnake = async (boardSize, ctx, snakePosArray=null) => {
    const snakePositions = snakePosArray === null ? getStartSnake(boardSize) : snakePosArray;
    const encodedMoves = getEncodedMoves(snakePositions);
    const image = getBlankImage(boardSize);
    drawSnake(image, snakePositions, encodedMoves);

    if (userSelectsApple) {
        // No apple exists yet, so the position argument is undefined and the apple is not drawn.
        drawImage(snakePositions, undefined, boardSize, ctx, false);
    }
    const applePosition = await getNewApple(boardSize, snakePositions);
    drawApple(image, applePosition);

    const closestDistance = getDistanceToApple(snakePositions, applePosition);
    const remainingMoves = getRemainingMoves(0, boardSize, true)[0];
    const safeMoves = getSafeMoves(snakePositions, boardSize);
    drawImage(snakePositions, applePosition, boardSize, ctx);

    return [image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, false, 0, 0];
};

// Returns the pause between steps in milliseconds, derived from the module-level speedPercent.
// The mapping is linear: 250 ms at 0 percent down to 1 ms at 100 percent.
const getDelay = () => {
    const slowest = 250;
    const fastest = 1;
    return slowest - (slowest - fastest) * speedPercent / 100;
};

// Converts a hexadecimal colour string to an [r, g, b] array of numbers in 0..255.
// Accepts the six-digit form with or without a leading '#' ("#ff8000", "ff8000") as well as the
// three-digit shorthand ("#f80"), where each digit is doubled. Digits that are not valid
// hexadecimal produce NaN for the affected channel.
const hexToRgb = (hex) => {
    let digits = hex.startsWith('#') ? hex.substring(1) : hex;
    if (digits.length === 3) {
        // Expand shorthand: "f80" -> "ff8800"
        digits = [...digits].map((digit) => digit + digit).join('');
    }
    return [0, 2, 4].map((offset) => Number.parseInt(digits.substring(offset, offset + 2), 16));
};

// Adds a drawn stroke to the custom snake and shows it on the canvas.
// The stroke's cells are appended to the module-level fullLine, and each cell of the stroke is
// filled white. Cell width and height are derived separately from the canvas width and height,
// the same way drawImage sizes its cells, so the preview lines up on a non-square canvas too.
const pixelateLine = (line) => {
    fullLine = fullLine.concat(line);
    const pixelWidth = ctx.canvas.width / boardSize;
    const pixelHeight = ctx.canvas.height / boardSize;
    ctx.fillStyle = 'white';
    line.forEach((point) => {
        ctx.fillRect(point[0] * pixelWidth, point[1] * pixelHeight, pixelWidth, pixelHeight);
    });
};

// Turns the cells the user drew into a game state.
// Repeated cells are dropped (the first occurrence is kept, order preserved) and the result is checked
// with checkSnakeValidity. Returns the state produced by prepareSnake for that snake, or null when the
// drawn snake is not valid.
const createCustomSnake = async (boardSize, ctx, drawnSnake) => {
    const seenKeys = new Set();
    const snakePositions = [];
    for (const cell of drawnSnake) {
        const key = JSON.stringify(cell);
        if (!seenKeys.has(key)) {
            seenKeys.add(key);
            snakePositions.push(JSON.parse(key));
        }
    }
    if (!checkSnakeValidity(snakePositions, boardSize)) {
        return null;
    }
    const state = await prepareSnake(boardSize, ctx, snakePositions);
    return state;
};

// Checks whether a list of [x, y] positions (head first) forms a usable snake.
// The snake is valid when:
// - it has at least 3 segments,
// - every segment lies on the board (0 <= x, y < boardSize),
// - every segment is exactly one horizontal or vertical step away from the previous one,
// - no cell is used more than once.
// Returns true when all conditions hold, false otherwise.
const checkSnakeValidity = (snakePositions, boardSize) => {
    if (snakePositions.length < 3) {
        return false;
    }
    const visited = new Set();
    for (let i = 0; i < snakePositions.length; i++) {
        const [x, y] = snakePositions[i];
        if (x < 0 || x >= boardSize || y < 0 || y >= boardSize) {
            return false;
        }
        const key = `${x},${y}`;
        if (visited.has(key)) {
            return false;
        }
        visited.add(key);
        if (i > 0) {
            const [prevX, prevY] = snakePositions[i - 1];
            // Exactly one step: rejects diagonals, jumps and zero-length steps.
            if (Math.abs(x - prevX) + Math.abs(y - prevY) !== 1) {
                return false;
            }
        }
    }
    return true;
};

// Creates an inference session for every entry of modelPaths and stores each one in the
// module-level models object under its path. The returned promise settles once all sessions
// have been created, or rejects as soon as one of them fails.
const loadModels = async () => {
    const pending = modelPaths.map((path) =>
        ort.InferenceSession.create(path).then((session) => {
            models[path] = session;
        }));
    return Promise.all(pending);
};

// Disables the play, pause and reset buttons, the grid checkbox, the custom-snake button,
// both colour selectors, the speed slider and every model button.
const disableAllControls = () => {
    const controls = [
        playButton,
        pauseButton,
        resetButton,
        gridCheckbox,
        drawCustomSnake,
        headColorSelector,
        tailColorSelector,
        speedSlider,
    ];
    for (const control of controls) {
        control.disabled = true;
    }
    modelButtons.forEach((button) => {
        button.disabled = true;
    });
};

// Re-enables the play, pause and reset buttons, the grid checkbox, the custom-snake button,
// both colour selectors, the speed slider and every model button.
const enableAllControls = () => {
    const controls = [
        playButton,
        pauseButton,
        resetButton,
        gridCheckbox,
        drawCustomSnake,
        headColorSelector,
        tailColorSelector,
        speedSlider,
    ];
    for (const control of controls) {
        control.disabled = false;
    }
    modelButtons.forEach((button) => {
        button.disabled = false;
    });
};
// ---- Module-level state and DOM lookups ----
// Everything below is shared by the functions above and the event handlers that follow.

// Canvas and 2D drawing context
const canvas = document.querySelector('#snakeAiCanvas');
const ctx = canvas.getContext('2d');
ctx.imageSmoothingEnabled = false;
// When true, drawImage also draws the grid.
let shouldDrawGrid = false;
// Custom-snake drawing: isDrawing is set between pointer down and up, line collects the cells of
// the stroke in progress, fullLine accumulates the cells of all finished strokes (see pixelateLine).
let isDrawing = false;
let line = [];
let fullLine = [];

// Model files and the board size each one is used with (same index in both arrays)
const modelPaths = ['./Snake/6x6_model.onnx', './Snake/10x10_model.onnx', './Snake/20x20_model.onnx'];
const modelSizes = [6, 10, 20];

// Loaded inference sessions, keyed by model path (filled by loadModels)
let models = {};

// Load all models before anything below runs; module evaluation waits here.
await loadModels();

// Active model and board size; the 10x10 entry is the default
let model = models[modelPaths[1]];
let boardSize = modelSizes[1];

// Model selection buttons
const button6 = document.querySelector('#model6');
const button10 = document.querySelector('#model10');
const button20 = document.querySelector('#model20');
const modelButtons = document.querySelectorAll('#modelButtons button');

// Speed control: speedPercent is read by getDelay
let speedPercent = 50;
const speedSlider = document.querySelector('#speedSlider');

// Snake colours as [r, g, b]: startColor is used at the head, endColor towards the tail (see drawImage)
let startColor = [0, 128, 0]; // Green
let endColor = [0, 128, 255]; // Blue
const headColorSelector = document.querySelector('#headColor');
const tailColorSelector = document.querySelector('#tailColor');

// Game state
// paused stops the loop in doSteps and in the play handler; pausedTemp and reset are the
// bookkeeping flags the play, pause and reset handlers use to decide whether to resume or restart.
let paused = true;
let pausedTemp = true;
let reset = false;
let done = false;
let score = 0;
let step = 0;
// Last cell clicked on the canvas and whether getNewApple should wait for such a click
let userSelectedApplePosition = null;
let userSelectsApple = false;



// UI buttons
const playButton = document.querySelector('#playButton');
const pauseButton = document.querySelector('#pauseButton');
const playPauseButtons = document.querySelectorAll('#playButton, #pauseButton');
const resetButton = document.querySelector('#resetButton');
const gridCheckbox = document.querySelector('#drawGrid');
const drawCustomSnake = document.querySelector('#drawCustomSnake');

// Canvas click: translate the click's client coordinates into a board cell and remember it
// as the user's apple choice (getNewApple polls this value while userSelectsApple is set).
canvas.addEventListener('click', (event) => {
    const bounds = canvas.getBoundingClientRect();
    const cellWidth = bounds.width / boardSize;
    const cellHeight = bounds.height / boardSize;
    userSelectedApplePosition = [
        Math.floor((event.clientX - bounds.left) / cellWidth),
        Math.floor((event.clientY - bounds.top) / cellHeight),
    ];
});

// Model size buttons (6x6, 10x10, 20x20).
// Clicking a button that is not already selected pauses and resets the game, switches the active
// model and board size to the entry with the same index in modelPaths / modelSizes, and prepares a
// fresh board. All controls are disabled while the new board is being prepared.
for (const [modelIndex, sizeButton] of [button6, button10, button20].entries()) {
    sizeButton.addEventListener('click', async () => {
        if (sizeButton.classList.contains('selected')) {
            return;
        }
        pauseButton.click();
        resetButton.click();
        disableAllControls();
        model = models[modelPaths[modelIndex]];
        boardSize = modelSizes[modelIndex];
        [image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition] = await prepareSnake(boardSize, ctx);
        enableAllControls();
    });
}

// Keep exactly one model button highlighted: clicking an unselected button removes the
// 'selected' class from all model buttons and adds it to the clicked one.
modelButtons.forEach((button) => {
    button.addEventListener('click', () => {
        const alreadySelected = button.classList.contains('selected');
        if (!alreadySelected) {
            modelButtons.forEach((other) => other.classList.remove('selected'));
            button.classList.add('selected');
        }
    });
});

// Speed slider: store the slider position as the current speed percentage.
// An input element's value is a string, so it is converted explicitly to keep speedPercent
// a number, like its initial value.
speedSlider.addEventListener('input', () => {
    speedPercent = Number(speedSlider.value);
});

// Colour pickers: the head picker sets startColor and the tail picker sets endColor,
// each converted from the picker's hex value to an [r, g, b] array.
for (const [selector, applyColor] of [
    [headColorSelector, (rgb) => { startColor = rgb; }],
    [tailColorSelector, (rgb) => { endColor = rgb; }],
]) {
    selector.addEventListener('input', () => {
        applyColor(hexToRgb(selector.value));
    });
}

// Event listener for play button
// Runs the game loop until the game is paused. The loop cooperates with the pause and reset
// handlers through the module-level flags paused, pausedTemp and reset.
playButton.addEventListener('click', async () => {
    // Already playing: ignore the click.
    if (playButton.classList.contains('selected')) {
        return;
    }
    playButton.classList.add('selected');
    pauseButton.classList.remove('selected');
    paused = false;
    while (paused === false) {
        // Start a new game when a reset was requested, or when pausedTemp is still false from the
        // previous iteration (it is set to false just below and only the pause handler sets it back
        // to true). On the first play click pausedTemp is true, so the existing board is continued.
        if (pausedTemp === false || reset === true) {
            [image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, done, score, step] = await prepareSnake(boardSize, ctx);
            reset = false;
        }
        pausedTemp = false;
        // prepareSnake awaits, so a handler may have set paused in the meantime; re-check before stepping.
        if (paused === false) {
            [image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, done, score, step] = await doSteps(image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, boardSize, step);
        }
        // The pause handler sets both flags to true, which ends the loop. The reset handler stores the
        // previous paused value in pausedTemp and sets paused to true, so a reset during play stops
        // doSteps and then continues here with paused restored to false.
        paused = pausedTemp;
    }
});

// Pause button: unless it is already selected, highlight it instead of the play button
// and set both pause flags so the running game loop stops.
pauseButton.addEventListener('click', () => {
    const alreadyPaused = pauseButton.classList.contains('selected');
    if (alreadyPaused) {
        return;
    }
    pauseButton.classList.add('selected');
    playButton.classList.remove('selected');
    pausedTemp = true;
    paused = true;
});

// Keep exactly one of the play / pause buttons highlighted: clicking an unselected one removes
// the 'selected' class from both and adds it to the clicked button.
playPauseButtons.forEach((button) => {
    button.addEventListener('click', () => {
        const alreadySelected = button.classList.contains('selected');
        if (!alreadySelected) {
            playPauseButtons.forEach((other) => other.classList.remove('selected'));
            button.classList.add('selected');
        }
    });
});

// Reset button: remember whether the game was paused, force the running loop to stop,
// and flag that the next iteration of the play loop must prepare a new game.
resetButton.addEventListener('click', async () => {
    pausedTemp = paused;
    paused = true;
    reset = true;
});

// Grid checkbox: store the new setting and, when it was switched on, draw the grid right away
// on top of the current canvas content.
gridCheckbox.addEventListener('change', async () => {
    const { checked } = gridCheckbox;
    shouldDrawGrid = checked;
    if (checked) {
        drawGrid(boardSize, ctx);
    }
});

// "User selects apple" checkbox: mirror its state into the userSelectsApple flag
// that getNewApple, stepSnake and prepareSnake read.
const userSelectsAppleCheckbox = document.querySelector('#userSelectsApple');
userSelectsAppleCheckbox.addEventListener('change', async () => {
    const { checked } = userSelectsAppleCheckbox;
    userSelectsApple = checked;
});


// Drawing a custom snake with the mouse.
// These handlers only act while the custom-snake button carries the 'selected' class.
// mousedown starts a new stroke, mousemove records the board cell under the pointer,
// and mouseup ends the stroke and hands it to pixelateLine.
canvas.addEventListener('mousedown', (event) => {
    if (drawCustomSnake.classList.contains('selected')) {
        isDrawing = true;
        line = [];
    }
});

canvas.addEventListener('mousemove', (event) => {
    const drawModeActive = drawCustomSnake.classList.contains('selected');
    if (!drawModeActive || !isDrawing) {
        return;
    }
    const bounds = canvas.getBoundingClientRect();
    const cellX = Math.floor((event.clientX - bounds.left) / (bounds.width / boardSize));
    const cellY = Math.floor((event.clientY - bounds.top) / (bounds.height / boardSize));
    line.push([cellX, cellY]);
});

canvas.addEventListener('mouseup', () => {
    if (drawCustomSnake.classList.contains('selected')) {
        isDrawing = false;
        pixelateLine(line);
    }
});

// Drawing a custom snake with touch input (start and move).
// These handlers only act while the custom-snake button carries the 'selected' class; in that mode
// the default touch behaviour is suppressed. touchstart begins a new stroke and touchmove records
// the board cell under the first touch point.
canvas.addEventListener('touchstart', (event) => {
    const drawModeActive = drawCustomSnake.classList.contains('selected');
    if (drawModeActive) {
        event.preventDefault();
        isDrawing = true;
        line = [];
    }
});

canvas.addEventListener('touchmove', (event) => {
    if (!drawCustomSnake.classList.contains('selected')) {
        return;
    }
    event.preventDefault();
    if (!isDrawing) {
        return;
    }
    const bounds = canvas.getBoundingClientRect();
    const firstTouch = event.touches[0];
    const cellX = Math.floor((firstTouch.clientX - bounds.left) / (bounds.width / boardSize));
    const cellY = Math.floor((firstTouch.clientY - bounds.top) / (bounds.height / boardSize));
    line.push([cellX, cellY]);
});

// Touch end while drawing a custom snake: finish the stroke and hand it to pixelateLine.
// Only acts while the custom-snake button carries the 'selected' class.
// The event is taken as a parameter so preventDefault is called on the event being handled
// rather than on the deprecated global window.event.
canvas.addEventListener('touchend', (event) => {
    if (!drawCustomSnake.classList.contains('selected')) {
        return;
    }
    event.preventDefault();
    isDrawing = false;
    pixelateLine(line);
});

// Toggle button for custom-snake drawing mode.
// Click while not selected: enter drawing mode — mark the button selected, pause the game, disable
// the play button and the "user selects apple" checkbox, and remember the checkbox state in
// tempUserSelectsApple so it can be restored later.
// Click while selected: leave drawing mode — build the game state from the drawn cells (falling back
// to the default snake when createCustomSnake returns null), clear the drawing and restore the controls.
let tempUserSelectsApple = userSelectsAppleCheckbox.checked;
drawCustomSnake.addEventListener('click', async () => {
    if (!drawCustomSnake.classList.contains('selected')) {
        drawCustomSnake.classList.add('selected');
        pauseButton.click();
        playButton.disabled = true;
        tempUserSelectsApple = userSelectsAppleCheckbox.checked;
        userSelectsAppleCheckbox.checked = false;
        userSelectsAppleCheckbox.disabled = true;
        userSelectsApple = false;
        return;
    }

    // Build the state once and reuse it: every createCustomSnake call prepares a game,
    // picks an apple and redraws the board, so calling it a second time would discard the first result.
    let newState = await createCustomSnake(boardSize, ctx, fullLine);
    if (newState === null) {
        newState = await prepareSnake(boardSize, ctx);
    }
    [image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition, done, score, step] = newState;

    fullLine = [];
    drawCustomSnake.classList.remove('selected');
    userSelectsAppleCheckbox.checked = tempUserSelectsApple;
    playButton.disabled = false;
    userSelectsApple = tempUserSelectsApple;
    userSelectsAppleCheckbox.disabled = false;
});



// Current game state, shared by the event handlers above.
// The variables are declared before the awaited call so that a handler firing while the first
// board is still being prepared sees them as undefined instead of hitting an uninitialised binding.
let image;
let closestDistance;
let remainingMoves;
let safeMoves;
let snakePositions;
let encodedMoves;
let applePosition;
[image, closestDistance, remainingMoves, safeMoves, snakePositions, encodedMoves, applePosition] = await prepareSnake(boardSize, ctx);